/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <float.h>

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/quantTypeUtils.cuh"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/kernels/quantization.h"

using namespace tensorrt_llm::common;

namespace tensorrt_llm {
namespace kernels {

__global__ static void quantizedKernel(char4* dst, float4 const* src, int64_t const sizeDiv4,
                                       float const* scalePtr) {
  for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < sizeDiv4;
       idx += blockDim.x * gridDim.x) {
    float const scale = __ldg(scalePtr);
    char4 tmp;
    float4 const floatTmp = __ldg(src + idx);
    tmp.x = cuda_cast<int8_t>(floatTmp.x * scale);
    tmp.y = cuda_cast<int8_t>(floatTmp.y * scale);
    tmp.z = cuda_cast<int8_t>(floatTmp.z * scale);
    tmp.w = cuda_cast<int8_t>(floatTmp.w * scale);
    dst[idx] = tmp;
  }
}

__global__ static void quantizedKernel(char4* dst, half2 const* src, int64_t const sizeDiv4,
                                       float const* scalePtr) {
  for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < sizeDiv4;
       idx += blockDim.x * gridDim.x) {
    float const scale = __ldg(scalePtr);
    char4 tmp;
    int64_t const srcId = idx << 1;

    uint2 const h2 = __ldg(reinterpret_cast<uint2 const*>(src + srcId));

    half2 const half2Tmp = reinterpret_cast<half2 const&>(h2.x);
    half2 const half2Tmp2 = reinterpret_cast<half2 const&>(h2.y);

    tmp.x = cuda_cast<int8_t>(cuda_cast<float>(half2Tmp.x) * scale);
    tmp.y = cuda_cast<int8_t>(cuda_cast<float>(half2Tmp.y) * scale);
    tmp.z = cuda_cast<int8_t>(cuda_cast<float>(half2Tmp2.x) * scale);
    tmp.w = cuda_cast<int8_t>(cuda_cast<float>(half2Tmp2.y) * scale);
    dst[idx] = tmp;
  }
}

#ifdef ENABLE_BF16
__global__ static void quantizedKernel(char4* dst, __nv_bfloat162 const* src,
                                       int64_t const sizeDiv4, float const* scalePtr) {
  for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < sizeDiv4;
       idx += blockDim.x * gridDim.x) {
    float const scale = __ldg(scalePtr);
    char4 tmp;
    int64_t const srcId = idx << 1;

    uint2 const h2 = __ldg(reinterpret_cast<uint2 const*>(src + srcId));

    __nv_bfloat162 const bfloat162Tmp = reinterpret_cast<__nv_bfloat162 const&>(h2.x);
    __nv_bfloat162 const bfloat162Tmp2 = reinterpret_cast<__nv_bfloat162 const&>(h2.y);

    tmp.x = cuda_cast<int8_t>(cuda_cast<float>(bfloat162Tmp.x) * scale);
    tmp.y = cuda_cast<int8_t>(cuda_cast<float>(bfloat162Tmp.y) * scale);
    tmp.z = cuda_cast<int8_t>(cuda_cast<float>(bfloat162Tmp2.x) * scale);
    tmp.w = cuda_cast<int8_t>(cuda_cast<float>(bfloat162Tmp2.y) * scale);

    dst[idx] = tmp;
  }
}
#endif
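
// Illustrative host-side launch sketch (an assumption added for documentation, not the library's
// production entry point): shows how the float overload of quantizedKernel above could be
// launched. The helper name, block size, and grid cap are illustrative choices only.
inline void exampleLaunchQuantizationFp32(int8_t* dst, float const* src, int64_t const numElems,
                                          float const* scalePtr, cudaStream_t stream) {
  // Each thread consumes one float4 and produces one char4, so the element count must be a
  // multiple of 4.
  TLLM_CHECK_WITH_INFO(numElems % 4 == 0, "numElems must be a multiple of 4.");
  int64_t const sizeDiv4 = numElems / 4;
  if (sizeDiv4 == 0) {
    return;
  }
  int const blockSize = 256;
  int64_t gridSize = (sizeDiv4 + blockSize - 1) / blockSize;
  if (gridSize > 65535) {
    gridSize = 65535;  // The grid-stride loop inside the kernel covers the remainder.
  }
  quantizedKernel<<<static_cast<int>(gridSize), blockSize, 0, stream>>>(
      reinterpret_cast<char4*>(dst), reinterpret_cast<float4 const*>(src), sizeDiv4, scalePtr);
}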

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, int NUM_ELTS>
struct DstVec {
  // Fail at compile time if no specialization matches (the condition depends on T, so it is only
  // evaluated on instantiation).
  static_assert(sizeof(T) == 0, "DstVec is not implemented for this type / element count.");
};

template <>
struct DstVec<float2, 2> {
  using Type = uint32_t;
};

template <>
struct DstVec<half2, 4> {
  using Type = uint2;
};

#ifdef ENABLE_BF16

template <>
struct DstVec<__nv_bfloat162, 4> {
  using Type = uint2;
};

#endif  // ENABLE_BF16

template <typename T>
struct DstVec<T, 4> {
  static_assert(sizeof(T) == 4, "not implemented.");
  using Type = uint32_t;
};

template <typename T>
struct DstVec<T, 8> {
  static_assert(sizeof(T) == 2, "not implemented.");
  using Type = uint2;
};
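
// Illustrative compile-time checks (added for documentation): one 16-byte input vector of packed
// pairs quantizes to two bytes per pair, so an 8-byte packed type (float2) maps to a 4-byte
// output word and a 4-byte packed type (half2 / __nv_bfloat162) maps to an 8-byte output word.
static_assert(std::is_same_v<DstVec<float2, 2>::Type, uint32_t>, "float2 output should be 4 bytes");
static_assert(std::is_same_v<DstVec<half2, 4>::Type, uint2>, "half2 output should be 8 bytes");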

////////////////////////////////////////////////////////////////////////////////////////////////////

// Helper function that clamps all elements in the vector and tracks their absolute maximum.
// Elements are handled as packed pairs so that hmax2-style instructions can be used where available.
template <typename T>
inline __device__ void clampAndAbsMax(T& localMax, uint4& vec, T const clampMin, T const clampMax) {
  static constexpr int NUM_ELTS = sizeof(uint4) / sizeof(T);

#pragma unroll
  for (int i = 0; i < NUM_ELTS; ++i) {
    T& val = reinterpret_cast<T*>(&vec)[i];
    val = cuda_clamp(val, clampMin, clampMax);
    localMax = cuda_max(localMax, cuda_abs(val));
  }
}

// Helper function that quantizes the vector and stores it to global memory.
// Elements are handled as packed pairs so that fast packed convert instructions can be used.
template <typename T, typename QuantT, bool USE_SMEM>
inline __device__ void quantizeAndStore(QuantT* dstPtr, uint4 vec, T const clampMin,
                                        T const clampMax, float const scaleOrigQuant) {
  static constexpr int NUM_ELTS = sizeof(uint4) / sizeof(T);

  using DstVecType = typename DstVec<T, NUM_ELTS>::Type;
  DstVecType dstVec;
#pragma unroll
  for (int i = 0; i < NUM_ELTS; ++i) {
    T val = reinterpret_cast<T*>(&vec)[i];
    // Values loaded from smem have already been clamped.
    if constexpr (!USE_SMEM) {
      val = cuda_clamp(val, clampMin, clampMax);
    }
    float2 val2 = cuda_cast<float2>(val);
    val2.x *= scaleOrigQuant;
    val2.y *= scaleOrigQuant;
    QuantT quantVal = cuda_cast<QuantT>(val2);
    reinterpret_cast<QuantT*>(&dstVec)[i] = quantVal;
  }
  // Store to destination buffer.
  *reinterpret_cast<DstVecType*>(dstPtr) = dstVec;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, typename QuantT, bool USE_SMEM>
__global__ void perTokenQuantization(QuantT* dst, T const* src, int64_t const numRows,
                                     int64_t const numCols, float const* clampPtr, float* scalePtr,
                                     float* sumPtr, bool hasFp8MinScaling) {
  // Smem buffer.
  extern __shared__ uint4 smemBuffer[];

  // The clamping minimum / maximum values.
  T const clampMin = cuda_cast<T>(clampPtr ? clampPtr[0] : -FLT_MAX);
  T const clampMax = cuda_cast<T>(clampPtr ? clampPtr[1] : FLT_MAX);

  // Pack two elements in order to use higher-throughput packed instructions.
  using T2 = typename packed_as<T, 2>::type;
  using QuantT2 = typename packed_as<QuantT, 2>::type;
  T2 const clampMin2 = cuda_cast<T2, T>(clampMin);
  T2 const clampMax2 = cuda_cast<T2, T>(clampMax);

  // The quantized data type's maximum value (upper-bound).
  static constexpr float MAX_QUANT_VAL = QuantTypeStaticVals<QuantT>::MAX_VAL;
  // The minimum scaling factor (lower-bound).
  static constexpr float MIN_SCALING_FACTOR = QuantTypeStaticVals<QuantT>::MIN_SCALING_FACTOR;
  static constexpr float MIN_SCALING_FACTOR_RCP =
      QuantTypeStaticVals<QuantT>::MIN_SCALING_FACTOR_RCP;

  // The number of elements in the packed uint4 vec.
  static constexpr int NUM_ELTS_PER_VEC = sizeof(uint4) / sizeof(T);
  static constexpr int NUM_ELTS2_PER_VEC = sizeof(uint4) / sizeof(T2);

  // The number of vectors in the column.
  int const numColVecs = numCols / NUM_ELTS_PER_VEC;
  // The vector pointers for src.
  uint4 const* srcVec = reinterpret_cast<uint4 const*>(src) + blockIdx.x * numColVecs;
  // The pointer for dst.
  QuantT* dstRow = dst + blockIdx.x * numCols;

  T2 localMax2 = cuda_cast<T2, T>(T(1e-6f));
  float2 localSum2 = {0.f, 0.f};

  for (int i = threadIdx.x; i < numColVecs; i += blockDim.x) {
    uint4 vec = srcVec[i];

#pragma unroll
    for (int j = 0; j < NUM_ELTS2_PER_VEC; ++j) {
      T2& val2 = reinterpret_cast<T2*>(&vec)[j];
      val2 = cuda_clamp(val2, clampMin2, clampMax2);
      localMax2 = cuda_max(localMax2, cuda_abs(val2));
      // TODO: template the version that requires sum to avoid dynamic branching.
      if (sumPtr != nullptr) {
        localSum2.x += cuda_cast<float>(val2.x);
        localSum2.y += cuda_cast<float>(val2.y);
      }
    }
    // Avoid reloading from global memory.
    if constexpr (USE_SMEM) {
      smemBuffer[i] = vec;
    }
  }
  float const rowMax = blockAllReduceMax(cuda_cast<float>(cuda_max<T, T2>(localMax2)));
  if (threadIdx.x == 0) {
    scalePtr[blockIdx.x] = hasFp8MinScaling ? cuda_max(rowMax / MAX_QUANT_VAL, MIN_SCALING_FACTOR)
                                            : (rowMax / MAX_QUANT_VAL);
  }

  if (sumPtr != nullptr) {
    float rowSum[1] = {cuda_sum<float>(localSum2)};
    blockReduceSumV2<float, 1>(rowSum);
    if (threadIdx.x == 0) {
      sumPtr[blockIdx.x] = rowSum[0];
    }
  }

  float const scaleOrigQuant = hasFp8MinScaling
                                   ? fminf(MAX_QUANT_VAL / rowMax, MIN_SCALING_FACTOR_RCP)
                                   : MAX_QUANT_VAL / rowMax;
  for (int i = threadIdx.x; i < numColVecs; i += blockDim.x) {
    uint4 vec = USE_SMEM ? smemBuffer[i] : srcVec[i];
    QuantT2* dstPtr = reinterpret_cast<QuantT2*>(dstRow + i * NUM_ELTS_PER_VEC);
    quantizeAndStore<T2, QuantT2, USE_SMEM>(dstPtr, vec, clampMin2, clampMax2, scaleOrigQuant);
  }
}
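
// Illustrative host-side launch sketch (an assumption added for documentation, not the library's
// production entry point): one thread block per row. The smem variant keeps the clamped row
// resident in shared memory so the quantization pass does not re-read global memory, at the cost
// of numCols * sizeof(T) bytes of dynamic shared memory per block. The helper name, block size,
// and the conservative 48 KB shared-memory limit are illustrative choices only.
template <typename T, typename QuantT>
inline void exampleLaunchPerTokenQuantization(QuantT* dst, T const* src, int64_t const numRows,
                                              int64_t const numCols, float const* clampPtr,
                                              float* scalePtr, float* sumPtr,
                                              bool hasFp8MinScaling, cudaStream_t stream) {
  dim3 const grid(static_cast<unsigned int>(numRows));
  dim3 const block(512);
  size_t const smemBytes = static_cast<size_t>(numCols) * sizeof(T);
  if (smemBytes <= 48 * 1024) {
    perTokenQuantization<T, QuantT, true><<<grid, block, smemBytes, stream>>>(
        dst, src, numRows, numCols, clampPtr, scalePtr, sumPtr, hasFp8MinScaling);
  } else {
    // Row does not fit in shared memory: fall back to re-reading the row from global memory.
    perTokenQuantization<T, QuantT, false><<<grid, block, 0, stream>>>(
        dst, src, numRows, numCols, clampPtr, scalePtr, sumPtr, hasFp8MinScaling);
  }
}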

////////////////////////////////////////////////////////////////////////////////////////////////////
// FP4/MXFP8 Quantization

constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
constexpr int CVT_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_THREADS_PER_WARP = 32;
constexpr int CVT_FP8_TO_FP4_ELTS_PER_THREAD = 16;
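
// For reference: e2m1 (FP4) can represent only {0, 0.5, 1, 1.5, 2, 3, 4, 6} (plus signs), so its
// maximum magnitude is 6.0, and e4m3 (FP8) has a maximum magnitude of 448. These two bounds are
// why 6.0f and 448.0f appear as normalization constants in the conversion routines below.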

// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), "f"(array[4]), "f"(array[5]),
        "f"(array[6]), "f"(array[7]));
  return val;
#else
  // static_assert(false, "not supported.");
  return 0;
#endif
}

// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), "f"(array[2].x),
        "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
  return val;
#else
  // static_assert(false, "not supported.");
  return 0;
#endif
}

// Convert 8 float2 values into 16 e2m1 values (represented as one uint64_t).
inline __device__ uint64_t fp32_vec_to_e2m1(float2 (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint64_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      ".reg .b8 byte4;\n"
      ".reg .b8 byte5;\n"
      ".reg .b8 byte6;\n"
      ".reg .b8 byte7;\n"
      ".reg .b32 val0;\n"
      ".reg .b32 val1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte0,  %2,  %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte1,  %4,  %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte2,  %6,  %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte3,  %8,  %7;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte4, %10,  %9;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte5, %12, %11;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte6, %14, %13;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte7, %16, %15;\n"
      "mov.b32 val0, {byte0, byte1, byte2, byte3};\n"
      "mov.b32 val1, {byte4, byte5, byte6, byte7};\n"
      "mov.b64 %0, {val0, val1};\n"
      "}"
      : "=l"(val)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), "f"(array[2].x),
        "f"(array[2].y), "f"(array[3].x), "f"(array[3].y), "f"(array[4].x), "f"(array[4].y),
        "f"(array[5].x), "f"(array[5].y), "f"(array[6].x), "f"(array[6].y), "f"(array[7].x),
        "f"(array[7].y));
  return val;
#else
  // static_assert(false, "not supported.");
  return 0;
#endif
}

// Convert 4 float2 values into 8 e4m3 values (represented as one uint64_t).
inline __device__ uint64_t fp32_vec_to_e4m3(float2 (&array)[4]) {
  union {
    uint64_t val;
    __nv_fp8x2_e4m3 elts[4];
  } u;

  static_assert(sizeof(u.val) == sizeof(u.elts),
                "Expected to alias uint64_t and __nv_fp8x2_e4m3[4]");

  u.elts[0] = __nv_fp8x2_e4m3(array[0]);
  u.elts[1] = __nv_fp8x2_e4m3(array[1]);
  u.elts[2] = __nv_fp8x2_e4m3(array[2]);
  u.elts[3] = __nv_fp8x2_e4m3(array[3]);
  return u.val;
}

// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
  float b;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
  return b;
}

__device__ __forceinline__ float exp2f_rcp(uint8_t exp) {
  constexpr uint32_t FP32_EXPONENT_BIAS = 127;
  return (exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(exp));
}
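
// An e8m0 byte encodes the power-of-two scale 2^(exp - 127); exp2f_rcp returns its reciprocal
// 2^(127 - exp). For example, exp = 129 encodes 4.0 and exp2f_rcp(129) == 0.25.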

// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
  typename TypeConverter<Type>::Type elts[4];
  static_assert(sizeof(elts) == sizeof(Type) * CVT_ELTS_PER_THREAD,
                "Vector size should match the number of elements per thread.");
};

template <>
struct PackedVec<__nv_fp8_e4m3> {
  __nv_fp8x2_e4m3 elts[8];
  static_assert(sizeof(elts) == sizeof(__nv_fp8_e4m3) * CVT_FP8_TO_FP4_ELTS_PER_THREAD,
                "Vector size should match the number of elements per thread.");
};

// Quantizes the provided PackedVec into the uint32_t output
template <class Type, int SF_VEC_SIZE, bool UE8M0_SF>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  // Get absolute maximum values among the local 8 values.
  auto localMax = cuda_abs(vec.elts[0]);

// Local maximum value.
#pragma unroll
  for (int i = 1; i < CVT_ELTS_PER_THREAD / 2; i++) {
    localMax = cuda_max(localMax, cuda_abs(vec.elts[i]));
  }

  constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_ELTS_PER_THREAD;
  // Reduce the absolute maximum across the threads that share one SF (two threads for SF vector
  // size 16, four threads for 32).
  localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
  if constexpr (CVT_NUM_THREADS_PER_SF == 4) {
    localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax);
  }
  // Get the final absolute maximum values.
  float vecMax = float(cuda_max(localMax.x, localMax.y));

  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  float outputScale;
  // Write the SF to global memory (STG.8).
  if constexpr (UE8M0_SF) {
    __nv_fp8_e8m0 tmp;
    // Scale the max value to the range of E2m1.
    vecMax *= reciprocal_approximate_ftz(6.0f);
    tmp.__x = __nv_cvt_float_to_e8m0(vecMax, __NV_SATFINITE, cudaRoundPosInf);

    fp8SFVal = tmp.__x;
    outputScale = vecMax != 0 ? exp2f_rcp(fp8SFVal) : 0.0f;
  } else {
    // Get the SF (max value of the vector / max value of e2m1).
    // maximum value of e2m1 = 6.0.
    // TODO: use half as compute data type.
    auto SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));

    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
    fp8SFVal = tmp.__x;
    SFValue = static_cast<float>(tmp);
    // Get the output scale.
    // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal)) * reciprocal(SFScaleVal))
    outputScale = vecMax != 0
                      ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal))
                      : 0.0f;
  }

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
  float2 fp2Vals[CVT_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) {
    if constexpr (std::is_same_v<Type, half>) {
      fp2Vals[i] = __half22float2(vec.elts[i]);
    } else {
      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e2m1 values.
  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

  // Write the e2m1 values to global memory.
  return e2m1Vec;
#else
  return 0;
#endif
}
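
// Worked example of the non-UE8M0 path above (illustrative numbers): with SFScaleVal = 1.0 and a
// per-vector absolute maximum of 3.0, SFValue = 3.0 / 6.0 = 0.5, which e4m3 represents exactly,
// so outputScale = 1 / 0.5 = 2.0 and the largest element maps to 3.0 * 2.0 = 6.0, the top of the
// e2m1 range.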

template <class Type, int SF_VEC_SIZE, bool UE8M0_SF>
__device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)

  float const dequant_to_fp16_scale = 6.f * reciprocal_approximate_ftz(SFScaleVal);

  // Dequant fp8 to fp16
  __half2 vec_half2[8];
#pragma unroll
  for (int i = 0; i < CVT_FP8_TO_FP4_ELTS_PER_THREAD / 2; i++) {
    float2 tmp = static_cast<float2>(vec.elts[i]);
    tmp.x *= dequant_to_fp16_scale;
    tmp.y *= dequant_to_fp16_scale;
    vec_half2[i] = __float22half2_rn(tmp);
  }

  // Get the absolute maximum among the local 16 values (8 packed half2 pairs).
  auto localMax = __habs2(vec_half2[0]);
  // Local maximum value.
#pragma unroll
  for (int i = 1; i < CVT_FP8_TO_FP4_ELTS_PER_THREAD / 2; i++) {
    localMax = __hmax2(localMax, __habs2(vec_half2[i]));
  }

  constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_FP8_TO_FP4_ELTS_PER_THREAD;
  if constexpr (CVT_NUM_THREADS_PER_SF == 2) {
    // For SF vector size 32, reduce the local max across the two threads that share one SF.
    localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
  }

  // Get the final absolute maximum values.
  float vecMax = float(__hmax(localMax.x, localMax.y));

  // Get the SF (max value of the vector / max value of e2m1).
  // maximum value of e2m1 = 6.0.
  // TODO: use half as compute data type.
  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
  float SFValueNarrow;
  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  // Write the SF to global memory (STG.8).
  if constexpr (UE8M0_SF) {
    __nv_fp8_e8m0 tmp;
    tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
    SFValueNarrow = static_cast<float>(tmp);
    fp8SFVal = tmp.__x;
  } else {
    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
    fp8SFVal = tmp.__x;
    SFValueNarrow = static_cast<float>(tmp);
  }
  // Get the output scale.
  // Recipe: final_scale = reciprocal(fp32(fp8(SFValue)) * reciprocal(SFScaleVal)).
  float outputScale = SFValue != 0 ? SFScaleVal * reciprocal_approximate_ftz(SFValueNarrow) : 0.0f;

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
  float2 fp2Vals[CVT_FP8_TO_FP4_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_FP8_TO_FP4_ELTS_PER_THREAD / 2; i++) {
    fp2Vals[i] = __half22float2(vec_half2[i]);
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e2m1 values.
  uint64_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

  // Write the e2m1 values to global memory.
  return e2m1Vec;
#else
  return 0;
#endif
}

// Quantizes the provided PackedVec into the uint64_t output
template <class Type, int SF_VEC_SIZE>
__device__ uint64_t cvt_warp_fp16_to_mxfp8(PackedVec<Type>& vec, uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  // Get absolute maximum values among the local 8 values.
  auto localMax = cuda_abs(vec.elts[0]);

// Local maximum value.
#pragma unroll
  for (int i = 1; i < CVT_ELTS_PER_THREAD / 2; i++) {
    localMax = cuda_max(localMax, cuda_abs(vec.elts[i]));
  }

  constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_ELTS_PER_THREAD;
  // Reduce the absolute maximum across the threads that share one SF (two threads for SF vector
  // size 16, four threads for 32).
  localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
  if constexpr (CVT_NUM_THREADS_PER_SF == 4) {
    localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax);
  }
  // Get the final absolute maximum values.
  float vecMax = float(cuda_max(localMax.x, localMax.y));

  // Get the SF (max value of the vector / max value of mxfp8).
  float SFValue = vecMax * reciprocal_approximate_ftz(448.0f);
  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  // Write the SF to global memory (STG.8).
  __nv_fp8_e8m0 tmpSFVal;
  tmpSFVal.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
  SFValue = static_cast<float>(tmpSFVal);
  fp8SFVal = tmpSFVal.__x;
  // Get the output scale (reciprocal of the SFValue).
  float outputScale = vecMax != 0.f ? reciprocal_approximate_ftz(SFValue) : 0.0f;

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
  float2 fp2Vals[CVT_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) {
    if constexpr (std::is_same_v<Type, half>) {
      fp2Vals[i] = __half22float2(vec.elts[i]);
    } else {
      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e4m3 values.
  uint64_t e4m3Vec = fp32_vec_to_e4m3(fp2Vals);

  // Write the e4m3 values to global memory.
  return e4m3Vec;
#else
  return 0;
#endif
}
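
// Worked example for the MXFP8 path above (illustrative numbers): a per-vector absolute maximum
// of 100.0 gives SFValue = 100 / 448 ~= 0.223, which the round-up e8m0 conversion turns into
// 0.25 (2^-2), so outputScale = 4.0 and the largest element becomes 400.0, inside the e4m3 range
// of +/-448.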

inline __device__ __host__ int64_t get_sf_out_offset_128x4(std::optional<int> batchIdx, int mIdx,
                                                           int kIdx, std::optional<int> numRows,
                                                           int numColVecs) {
  // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
  // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]

  // batched tensor
  // SF layout [numBTiles, numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
  // --> index [bTileIdx, mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]

  int32_t innerKIdx = (kIdx % 4);
  int64_t innerKStride = 1;

  int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
  int64_t innerMStride = 4 * innerKStride;  // 4

  // M tile layout [32, 4] is column-major.
  int32_t outerMIdx = (mIdx % 32);
  int64_t outerMStride = 4 * innerMStride;  // 16

  int32_t kTileIdx = (kIdx / 4);
  int64_t kTileStride = 32 * outerMStride;  // 512

  // SF vector size 16 or 32. We round the "numCols" up to a multiple of 64 or 128.
  // It is the same as rounding the "numColVecs" up to a multiple of 4.
  int32_t numKTiles = (numColVecs + 4 - 1) / 4;

  int32_t mTileIdx = mIdx / (32 * 4);
  int64_t mTileStride = numKTiles * kTileStride;

  // Each SF block has 128 rows so pad rows to the multiple of 128.
  int32_t numMTiles = (numRows.value_or(0) + 128 - 1) / 128;
  int64_t bTileStride = numMTiles * mTileStride;

  // Compute the global offset.
  int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride +
                     kTileIdx * kTileStride + outerMIdx * outerMStride + innerMIdx * innerMStride +
                     innerKIdx * innerKStride;

  return SFOffset;
}
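
// Worked example (illustrative): for mIdx = 33, kIdx = 1, numColVecs = 4 (so numKTiles = 1), the
// indices decompose as mTileIdx = 0, kTileIdx = 0, outerMIdx = 1, innerMIdx = 1, innerKIdx = 1,
// giving SFOffset = 1 * 16 + 1 * 4 + 1 = 21.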

inline __device__ __host__ int64_t get_sf_out_offset_8x4(std::optional<int> batchIdx, int mIdx,
                                                         int kIdx, std::optional<int> numRows,
                                                         int numCols) {
  // SF layout [numMTiles, numKTiles, 8 (mTile), 4(kTile)]
  // --> index [mTileIdx, kTileIdx, innerMIdx, innerKIdx]

  // batched tensor
  // SF layout [numBTiles, numMTiles, numKTiles, 8 (mTile), 4(kTile)]
  // --> index [bTileIdx, mTileIdx, kTileIdx, innerMIdx, innerKIdx]
  const int32_t mTile = 8;
  int32_t innerKIdx = (kIdx % 4);
  int64_t innerKStride = 1;

  int32_t innerMIdx = (mIdx % mTile);
  int64_t mStride = 4 * innerKStride;

  int32_t kTileIdx = (kIdx / 4);
  int64_t kTileStride = mTile * mStride;

  int32_t numKTiles = (numCols + 4 - 1) / 4;
  int32_t mTileIdx = mIdx / mTile;
  int64_t mTileStride = numKTiles * kTileStride;

  int32_t numMTiles = (numRows.value_or(0) + 8 - 1) / 8;
  int64_t bTileStride = numMTiles * mTileStride;

  int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride +
                     kTileIdx * kTileStride + innerMIdx * mStride + innerKIdx * innerKStride;

  return SFOffset;
}

template <class SFType, int CVT_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_get_sf_out_offset(std::optional<int> batchIdx, int rowIdx,
                                                int colVecIdx, std::optional<int> numRows,
                                                int numColVecs, SFType* SFout,
                                                QuantizationSFLayout layout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CVT_NUM_THREADS_PER_SF == 1 || CVT_NUM_THREADS_PER_SF == 2 ||
                CVT_NUM_THREADS_PER_SF == 4);

  // One thread out of every CVT_NUM_THREADS_PER_SF threads writes one SF to global memory.
  // TODO: stage through smem for packed STG.32.
  // Is it better than STG.8 from 4 threads?
  if (threadIdx.x % CVT_NUM_THREADS_PER_SF == 0) {
    if (layout == QuantizationSFLayout::SWIZZLED_128x4 ||
        layout == QuantizationSFLayout::SWIZZLED_8x4) {
      // SF vector index (16 elements share one SF in the K dimension).
      // numRows and numCols are unpadded.
      int32_t kIdx = colVecIdx / CVT_NUM_THREADS_PER_SF;
      int32_t mIdx = rowIdx;

      auto SFOffset = layout == QuantizationSFLayout::SWIZZLED_128x4
                          ? get_sf_out_offset_128x4(batchIdx, mIdx, kIdx, numRows, numColVecs)
                          : get_sf_out_offset_8x4(batchIdx, mIdx, kIdx, numRows, numColVecs);
      return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
    } else if (layout == QuantizationSFLayout::LINEAR) {
      // Linear row-major layout, no padding required.
      int32_t KTileIdx = colVecIdx / CVT_NUM_THREADS_PER_SF;

      int32_t numKTiles = numColVecs;
      int64_t mTileStride = numKTiles;

      int64_t BTileStride = numRows.value_or(0) * mTileStride;

      int64_t SFOffset = batchIdx.value_or(0) * BTileStride + rowIdx * mTileStride + KTileIdx;
      return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
    } else {
      return nullptr;
    }
  }
#endif
  return nullptr;
}
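
// Worked example for the LINEAR branch above (illustrative): with CVT_NUM_THREADS_PER_SF = 2,
// numColVecs = 8, numRows = 4, batchIdx = 1, rowIdx = 2, and colVecIdx = 6, the writing thread
// stores its SF at byte offset 1 * (4 * 8) + 2 * 8 + 6 / 2 = 51 from SFout.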

template <BlockScaleQuantizationType quantization_type, class Type, int SF_VEC_SIZE, bool UE8M0_SF>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) quantize_with_block_size(
#else
quantize_with_block_size(
#endif
    int32_t numbatches, int32_t numRows, int32_t numCols, int32_t numPaddedCols, Type const* in,
    float const* SFScale, uint32_t* out, uint32_t* SFout, QuantizationSFLayout layout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)

  // The elements per thread.
  static constexpr int ELTS_PER_THREAD = quantization_type == BlockScaleQuantizationType::FP8_TO_FP4
                                             ? CVT_FP8_TO_FP4_ELTS_PER_THREAD
                                             : CVT_ELTS_PER_THREAD;

  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / ELTS_PER_THREAD;  // 2 or 4
  static_assert(sizeof(PackedVec) == sizeof(Type) * ELTS_PER_THREAD, "Vec size is not matched.");

  // Get the global scaling factor, which will be applied to the SF.
  // Note SFScale is the same as next GEMM's alpha, which is (448.f / (Alpha_A / 6.f)).
  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];

  // Is this a swizzled SF layout?
  bool isSfSwizzledLayout = layout == QuantizationSFLayout::SWIZZLED_128x4 ||
                            layout == QuantizationSFLayout::SWIZZLED_8x4;

  // The number of padded rows considering 128x4 SF layout.
  int numPaddedRowsForSf = isSfSwizzledLayout ? PadUpFn(numRows, 128) : numRows;
  int numColsForSf = isSfSwizzledLayout ? PadUpFn(numPaddedCols, 4 * SF_VEC_SIZE) : numPaddedCols;

  // The number of threads in the column dimension.
  // Note that numCols/numPaddedCols/numColsForSf are guaranteed to be multiples of ELTS_PER_THREAD.
  int numColThreads = numCols / ELTS_PER_THREAD;
  int numPaddedColThreads = numPaddedCols / ELTS_PER_THREAD;
  int numColThreadsForSf = numColsForSf / ELTS_PER_THREAD;

  asm volatile("griddepcontrol.wait;");
  // Input tensor batch/row/col loops.
  for (int rowIdx = blockIdx.x; rowIdx < numPaddedRowsForSf; rowIdx += gridDim.x) {
    for (int batchIdx = 0; batchIdx < numbatches; batchIdx++) {
      for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x) {
        std::optional<int> optionalBatchIdx = batchIdx;
        std::optional<int> optionalNumRows = numRows;

        // The SF output pointer.
        auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
            optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout,
            layout);

        // The input tensor offset.
        int64_t inOffset =
            static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
        int64_t outOffset =
            static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx;

        // Zero out the output values in the padded columns.
        if (rowIdx < numRows && colIdx >= numColThreads && colIdx < numPaddedColThreads) {
          // Choose the zero word width that matches the quantized output type.
          if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
            reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4 ||
                               quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
            reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
          }
        }

        // Set the SF padding to 0.
        if (rowIdx >= numRows || colIdx >= numColThreads) {
          if (sf_out != nullptr) {
            sf_out[0] = 0x00;
          }
        } else {
          // Load the input vector.
          PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];

          // Dispatch the quantization kernel.
          if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4) {
            reinterpret_cast<uint32_t*>(out)[outOffset] =
                cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4) {
            reinterpret_cast<uint64_t*>(out)[outOffset] =
                cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal,
                                                                          sf_out);
          } else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8) {
            reinterpret_cast<uint64_t*>(out)[outOffset] =
                cvt_warp_fp16_to_mxfp8<Type, SF_VEC_SIZE>(in_vec, sf_out);
          }
        }
      }
    }
  }
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
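
// Illustrative host-side launch sketch (an assumption added for documentation, not the library's
// production entry point): quantize_with_block_size uses a grid-stride loop over rows and is
// bounded to 512 threads per block, so any grid size works. The helper name and the grid cap of
// 4096 are illustrative choices only.
template <BlockScaleQuantizationType QType, class Type, int SF_VEC_SIZE, bool UE8M0_SF>
inline void exampleLaunchQuantizeWithBlockSize(int32_t numBatches, int32_t numRows,
                                               int32_t numCols, Type const* in,
                                               float const* SFScale, uint32_t* out,
                                               uint32_t* SFout, QuantizationSFLayout layout,
                                               cudaStream_t stream) {
  if (numRows <= 0) {
    return;
  }
  dim3 const block(512);
  int const numBlocks = numRows < 4096 ? numRows : 4096;
  // Columns are assumed to be unpadded here, so numPaddedCols == numCols.
  quantize_with_block_size<QType, Type, SF_VEC_SIZE, UE8M0_SF><<<numBlocks, block, 0, stream>>>(
      numBatches, numRows, numCols, /*numPaddedCols=*/numCols, in, SFScale, out, SFout, layout);
}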

__global__ void block_scale_interleave_kernel(int numbatches, int numRows, int numCols,
                                              uint8_t const* SFIn, uint8_t* SFOutput);
}  // namespace kernels
}  // namespace tensorrt_llm
